{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f3949bb5-689d-4f84-a247-92bb06b831b3",
   "metadata": {},
   "source": [
    "<center><img src=\"https://imgcdn.stablediffusionweb.com/2025/1/3/355bcb46-cdbe-4632-9e49-ba0b79fd662e.jpg\"/></center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16ebae21-1be6-42ee-ab7c-a197a29c5ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import chess\n",
    "import chess.svg\n",
    "from stockfish import Stockfish\n",
    "import re\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5693c761-ea13-46e8-b664-39ed96f84f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import SVG, display, clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be511c1-4881-4912-ab30-234bfa4a83d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ollama import chat\n",
    "from ollama import ChatResponse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad8a9a6-b2b9-4be7-947d-55fdcc2242a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pgn_of_moves(move_list):\n",
    "    \"\"\"\n",
    "    Renders a list of (white, black) SAN move pairs as numbered PGN movetext.\n",
    "\n",
    "    Args:\n",
    "        move_list (list): Pairs of SAN strings. The black move may be None when\n",
    "            white's last move has not been answered yet (moves_of_pgn produces such pairs).\n",
    "\n",
    "    Returns:\n",
    "        str: Movetext such as \"1.e4 e5 2.Nf3 \", with a trailing space after every move.\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    for number, (white_move, black_move) in enumerate(move_list, start=1):\n",
    "        chunks.append(f\"{number}.{white_move} \")\n",
    "        # A missing reply must be left out, not rendered as the literal text \"None\".\n",
    "        if black_move is not None:\n",
    "            chunks.append(f\"{black_move} \")\n",
    "    return \"\".join(chunks)\n",
    "\n",
    "def moves_of_pgn(opening):\n",
    "    \"\"\"\n",
    "    Converts PGN movetext into a list of (white_move, black_move) pairs.\n",
    "\n",
    "    Move numbers are accepted either as separate tokens (\"1. e4\") or glued to\n",
    "    white's move (\"1.e4\"). When white's last move has no reply, the second\n",
    "    element of that pair is None.\n",
    "    \"\"\"\n",
    "    words = opening.split()\n",
    "    total = len(words)\n",
    "    pairs = []\n",
    "    pos = 0\n",
    "    while pos < total:\n",
    "        word = words[pos]\n",
    "        pos += 1\n",
    "        if word.endswith('.'):\n",
    "            # A stand-alone move number such as \"1.\" carries no move.\n",
    "            continue\n",
    "        # Keep only what follows the last dot, which drops a glued move number.\n",
    "        white = word.rpartition('.')[2]\n",
    "        black = None\n",
    "        if pos < total and not words[pos].endswith('.'):\n",
    "            black = words[pos]\n",
    "            pos += 1\n",
    "        pairs.append((white, black))\n",
    "    return pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72bb9e63-37f4-474a-8a57-76a21998294a",
   "metadata": {},
   "outputs": [],
   "source": [
    "opening='1.e4 e5 2.Nf3 Nc6 3.Bb5 a6 4.Ba4 Nf6'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e5c673-0786-4d3c-b645-0b8a82bbae57",
   "metadata": {},
   "outputs": [],
   "source": [
    "moves_of_pgn(opening)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab9326f4-4fdf-49fe-a118-d04bcd654e1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pgn_of_moves(moves_of_pgn(opening))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8991967e-bbd8-47ea-a344-6d1202e781f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_next_move(input_string):\n",
    "    \"\"\"\n",
    "    Returns the text enclosed by the first <next_move>...</next_move> pair.\n",
    "\n",
    "    Args:\n",
    "        input_string (str): Text that may contain the tags.\n",
    "\n",
    "    Returns:\n",
    "        str: The enclosed text with surrounding whitespace removed, or None when no tag pair is present.\n",
    "    \"\"\"\n",
    "    # DOTALL lets the enclosed text span several lines.\n",
    "    found = re.search(r\"<next_move>(.*?)</next_move>\", input_string, re.DOTALL)\n",
    "    return found.group(1).strip() if found else None\n",
    "\n",
    "def parse_move(input_string):\n",
    "    \"\"\"\n",
    "    Splits a numbered move such as \"5.O-O\" into its move number and the move text.\n",
    "\n",
    "    Surrounding whitespace and a single pair of enclosing square brackets are tolerated\n",
    "    (\"[5.O-O]\"), because the prompt template shows the expected answer inside brackets\n",
    "    and a model may copy them literally.\n",
    "\n",
    "    Args:\n",
    "        input_string (str): The string to test and parse.\n",
    "\n",
    "    Returns:\n",
    "        tuple: A tuple containing the integer and the string after the dot, or None if the pattern does not match.\n",
    "    \"\"\"\n",
    "    candidate = input_string.strip()\n",
    "    # Square brackets never occur in SAN, so dropping an enclosing pair loses nothing.\n",
    "    if candidate.startswith(\"[\") and candidate.endswith(\"]\"):\n",
    "        candidate = candidate[1:-1].strip()\n",
    "    match = re.match(r\"^(\\d+)\\.(.*)$\", candidate)\n",
    "    if match is None:\n",
    "        return None\n",
    "    return int(match.group(1)), match.group(2).strip()\n",
    "    \n",
    "# Template sent to the chat model. LLM.make_prompt fills {pgn_moves} with the game so far\n",
    "# and {i} with the number of the move being requested (len(move_list) + 1).\n",
    "prompt = \"\"\"Here is a start of a chess game using Portable Game Notation (PGN):\n",
    "{pgn_moves}\n",
    "What would you play next? Explain your thought process and give your recommendation. Your recommendation should be written in PGN and contain only the last move.\n",
    "Please format your answer as follows:\n",
    "<explanation>[your thought process]</explanation>\n",
    "<next_move>[{i}.your move in PGN]</next_move>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "216c70f1-079d-4b67-8174-3b8d97aa3967",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Game:\n",
    "    \"\"\"\n",
    "    Runs a chess game in which the LLM moves and Stockfish replies.\n",
    "\n",
    "    The python-chess board is the reference state; every move pushed on it is also\n",
    "    sent to the Stockfish process so that both stay in sync.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, llm, elo=1200, stockfish_path=\"/opt/homebrew/bin/stockfish/\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            llm: Object exposing nxt_move_llm(move_list) that returns (move_number, san) or None.\n",
    "            elo (int): Value passed to Stockfish as its \"UCI_Elo\" parameter.\n",
    "            stockfish_path (str): Location of the Stockfish executable. The default is the path\n",
    "                this notebook was written with; pass your own for other installations.\n",
    "        \"\"\"\n",
    "        self.board = chess.Board()\n",
    "        self.elo = elo\n",
    "        self.stockfish = Stockfish(\n",
    "            path=stockfish_path,\n",
    "            depth=3,\n",
    "            parameters={\n",
    "                \"Threads\": 2,\n",
    "                \"Minimum Thinking Time\": 30,\n",
    "                \"UCI_Elo\": self.elo,\n",
    "            },\n",
    "        )\n",
    "        self.llm = llm\n",
    "        self.not_finish = True  # Cleared when the game ends or the LLM fails to give a legal move\n",
    "        self.logs = []  # Messages re-printed under the board on every refresh\n",
    "        # (llm_san, stockfish_san) pairs; created here so nxt_move also works before start_game.\n",
    "        self.move_list = []\n",
    "\n",
    "    def log(self, message):\n",
    "        \"\"\"Stores a message in the log; display_logs_and_board prints the stored messages.\"\"\"\n",
    "        self.logs.append(message)\n",
    "\n",
    "    def display_logs_and_board(self, size=400):\n",
    "        \"\"\"Clears the cell output, then shows the current board followed by all stored logs.\"\"\"\n",
    "        clear_output(wait=True)  # Clear previous output\n",
    "\n",
    "        # Display the board\n",
    "        board_svg = chess.svg.board(self.board, size=size)\n",
    "        display(SVG(board_svg))\n",
    "        # Display logs\n",
    "        for log in self.logs:\n",
    "            print(log)\n",
    "\n",
    "    def update_moves(self, move_w, move_b):\n",
    "        \"\"\"Plays one pair of SAN moves on the board and mirrors each of them in Stockfish.\"\"\"\n",
    "        for move in (move_w, move_b):\n",
    "            board_move = self.board.parse_san(move)\n",
    "            self.board.push(board_move)\n",
    "            # The Stockfish wrapper works with UCI strings, so pass one explicitly.\n",
    "            self.stockfish.make_moves_from_current_position([board_move.uci()])\n",
    "\n",
    "    def start_game(self, opening):\n",
    "        \"\"\"\n",
    "        Plays the opening moves given as PGN movetext.\n",
    "\n",
    "        The opening must consist of complete white/black pairs, because update_moves\n",
    "        parses both entries of every pair as SAN.\n",
    "        \"\"\"\n",
    "        self.move_list = moves_of_pgn(opening)\n",
    "        for (move_w, move_b) in self.move_list:\n",
    "            self.update_moves(move_w, move_b)\n",
    "\n",
    "    def _finish_if_game_over(self):\n",
    "        \"\"\"Prints the outcome and stops the game loop when the game has ended; returns whether it has.\"\"\"\n",
    "        if not self.board.is_game_over():\n",
    "            return False\n",
    "        print(\"GAME-OVER\")\n",
    "        print(self.board.outcome())\n",
    "        self.not_finish = False\n",
    "        return True\n",
    "\n",
    "    def nxt_move(self):\n",
    "        \"\"\"\n",
    "        Plays one round: the LLM's move followed by Stockfish's reply.\n",
    "\n",
    "        If the LLM answer cannot be parsed or is not a legal move, the problem is logged,\n",
    "        a random legal move is played instead and the game stops after this round.\n",
    "        \"\"\"\n",
    "        nxt_move_llm = self.llm.nxt_move_llm(self.move_list)\n",
    "        move_llm = None\n",
    "        if nxt_move_llm is None:\n",
    "            # Logged rather than printed: a bare print is wiped by the next clear_output.\n",
    "            self.log('Parsing error')\n",
    "        else:\n",
    "            try:\n",
    "                move_llm = self.board.parse_san(nxt_move_llm[1])\n",
    "            except ValueError as err:\n",
    "                # python-chess reports invalid, illegal and ambiguous SAN as ValueError subclasses.\n",
    "                self.log(f\"Illegal move {err=}, {type(err)=}\")\n",
    "\n",
    "        if move_llm is None:\n",
    "            self.not_finish = False  # stop after first illegal move\n",
    "            move_llm = random.choice(list(self.board.legal_moves))\n",
    "        # Canonical SAN (computed before the push), so move_list never stores the model's own spelling.\n",
    "        move_san = self.board.san(move_llm)\n",
    "        self.board.push(move_llm)\n",
    "        self.stockfish.make_moves_from_current_position([move_llm.uci()])\n",
    "        self.log(f\"LLM Move: {move_san}\")\n",
    "        self.display_logs_and_board()  # Display the board after LLM's move\n",
    "        if self._finish_if_game_over():\n",
    "            return\n",
    "\n",
    "        best = self.stockfish.get_best_move()\n",
    "        self.stockfish.make_moves_from_current_position([best])\n",
    "        move = chess.Move.from_uci(best)\n",
    "        st_san = self.board.san(move)\n",
    "        self.board.push(move)\n",
    "        self.move_list.append((move_san, st_san))\n",
    "        self.log(f\"Stockfish Move: {st_san}\")\n",
    "        self.display_logs_and_board()  # Display the board after Stockfish's move\n",
    "        self._finish_if_game_over()\n",
    "\n",
    "    def play_game(self, opening):\n",
    "        \"\"\"Plays the opening, then alternates LLM and Stockfish moves until not_finish is cleared.\"\"\"\n",
    "        self.start_game(opening)\n",
    "        self.display_logs_and_board()  # Display the initial logs and board\n",
    "        while self.not_finish:\n",
    "            self.nxt_move()\n",
    "    \n",
    "\n",
    "class LLM:\n",
    "    \"\"\"Wrapper around an Ollama chat model that proposes the next chess move.\"\"\"\n",
    "\n",
    "    def __init__(self, model='llama3.2', prompt=prompt):\n",
    "        self.model = model\n",
    "        self.prompt = prompt\n",
    "\n",
    "    def make_prompt(self, move_list):\n",
    "        \"\"\"Fills the prompt template with the moves played so far and the upcoming move number.\"\"\"\n",
    "        return self.prompt.format(\n",
    "            pgn_moves=pgn_of_moves(move_list),\n",
    "            i=len(move_list) + 1,\n",
    "        )\n",
    "\n",
    "    def nxt_move_llm(self, move_list, verbose = True):\n",
    "        \"\"\"\n",
    "        Asks the model for its next move.\n",
    "\n",
    "        Returns what parse_move yields for the <next_move> content, or None when the\n",
    "        answer has no (or an empty) <next_move> section. With verbose set, the raw\n",
    "        answer is printed first.\n",
    "        \"\"\"\n",
    "        user_message = {'role': 'user', 'content': self.make_prompt(move_list)}\n",
    "        response: ChatResponse = chat(model=self.model, messages=[user_message])\n",
    "        answer = response.message.content\n",
    "        if verbose:\n",
    "            print(answer)\n",
    "        next_move = extract_next_move(answer)\n",
    "        return parse_move(next_move) if next_move else None\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d0afac-6413-4b55-b85f-727c672e5d43",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = LLM('mistral')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81031d83-af62-48b8-96cf-5d8c7903f045",
   "metadata": {},
   "outputs": [],
   "source": [
    "game = Game(llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a87555-a129-4bf2-8b34-ef678cfd75c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# An empty opening makes moves_of_pgn return no moves, so no opening moves are played before the LLM's first move.\n",
    "game.play_game('')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d12a836-5b56-4870-99cb-3f02e7fb3747",
   "metadata": {},
   "source": [
    "# Playing with nanoGPT\n",
    "\n",
    "![](https://github.com/karpathy/nanoGPT/raw/master/assets/nanogpt.jpg)\n",
    "\n",
    "Following [Adam Karvonen](https://adamkarvonen.github.io/machine_learning/2024/01/03/chess-world-models.html), I trained a [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master) model in order to play chess.\n",
    "\n",
    "You can get it by following [these steps](https://github.com/dataflowr/notebooks/blob/master/llm/02_get_model.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c2fefb-870c-44b1-98fb-da0f7f4ffe47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from nanogpt.model_inf import GPTConfig, GPT\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd12110-a885-4f7e-a1c3-41b1861819ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "meta_path = \"nanogpt/meta.pkl\" # do not change if you followed the instructions in the 02_get_model\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only open a meta.pkl obtained from a source you trust.\n",
    "with open(meta_path, \"rb\") as f:\n",
    "    meta = pickle.load(f)\n",
    "\n",
    "# Lookup tables in both directions (string -> id and id -> string), used by encode / decode below.\n",
    "stoi = meta[\"stoi\"]\n",
    "itos = meta[\"itos\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "817cee11-ed7b-45b6-90fb-23c2a4f27445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the first available backend; assign torch.device('cpu'), 'cuda' or 'mps' instead to force one.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# TODO: set this to where you saved your model; 'nanogpt/ckpt.pt' is only a placeholder guess.\n",
    "checkpoint_path = 'nanogpt/ckpt.pt'\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd56033-dbb0-4b8d-8f59-90cb4b6893fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "gptconf = GPTConfig(**checkpoint[\"model_args\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b3a61b-7c77-419b-8776-2e2ed37a4dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPT(gptconf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33ab0fbb-cd78-4ed1-9d6e-abb838087043",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_dict = checkpoint[\"model\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3932820-196d-4def-81a9-fa20589365e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename every key that starts with the prefix so that it loses it\n",
    "# (presumably the prefix comes from a compiled model wrapper; confirm against the training setup).\n",
    "unwanted_prefix = \"_orig_mod.\"\n",
    "prefix_length = len(unwanted_prefix)\n",
    "# Iterate over a snapshot of the keys because the dict is modified inside the loop.\n",
    "for key in list(state_dict.keys()):\n",
    "    if key.startswith(unwanted_prefix):\n",
    "        state_dict[key[prefix_length:]] = state_dict.pop(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58867ac-caaa-44ca-8e9d-850cc87206a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf8b0ee-d7c7-44ba-aee8-3f5ed58656f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95691486-cbc0-4f0f-9c80-a115bc84040b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(s):\n",
    "    \"\"\"Looks up every element (character) of s in stoi and returns the list of results.\"\"\"\n",
    "    return [stoi[c] for c in s]\n",
    "\n",
    "\n",
    "def decode(l):\n",
    "    \"\"\"Looks up every id of l in itos and joins the results into one string.\"\"\"\n",
    "    return \"\".join(itos[i] for i in l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad169a60-8257-462c-af96-24c4e24a670a",
   "metadata": {},
   "outputs": [],
   "source": [
    "game_start = ';1.e4 e5 '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f01461e-4944-4eb1-b11e-5d1f014c4ea2",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_ids = encode(game_start)\n",
    "start_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8005130-3410-4dea-b841-aa7199ca4be5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling settings\n",
    "max_new_tokens = 400\n",
    "temperature = 0.01  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions\n",
    "top_k = None  # e.g. 200: retain only the top_k most likely tokens, clamp others to have 0 probability\n",
    "\n",
    "# Add a leading dimension of size 1 in front of the prompt ids.\n",
    "x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)\n",
    "\n",
    "# Turn the first generated sequence back into text.\n",
    "model_response = decode(y[0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884aa1d9-3c22-4e5f-99c0-c4f26f7d9205",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "714541d8-a4c7-4304-90b0-c99e85e13cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "y[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c797f9d-ef08-437f-855a-0881a8fb14fc",
   "metadata": {},
   "source": [
    "Now define a class `GPT_chess` in order to play games with the model you downloaded.\n",
    "\n",
    "Then improve it. However, you are not allowed to use the Stockfish engine inside your LLM! You may still check that the moves it proposes are valid."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chess",
   "language": "python",
   "name": "chess"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
